"""Example training a memory neural net on the bAbI dataset.

Built with Keras (via ``tensorflow.keras``) and based on https://keras.io/examples/babi_memnn/.
"""

from __future__ import print_function

import argparse
import os
import re
import sys
import tarfile

import numpy as np
from filelock import FileLock

from ray import tune

if sys.version_info >= (3, 12):
    # Skip this test in Python 3.12+ because TensorFlow is not supported.
    sys.exit(0)
else:
    from tensorflow.keras.layers import (
        LSTM,
        Activation,
        Dense,
        Dropout,
        Embedding,
        Input,
        Permute,
        add,
        concatenate,
        dot,
    )
    from tensorflow.keras.models import Model, Sequential, load_model
    from tensorflow.keras.optimizers import RMSprop
    from tensorflow.keras.preprocessing.sequence import pad_sequences
    from tensorflow.keras.utils import get_file


def tokenize(sent):
    """Return the tokens of a sentence including punctuation.

    The sentence is split on runs of non-word characters. Because the split
    pattern has a capturing group, the separators are kept; they are then
    stripped, so punctuation survives as its own token while pure whitespace
    is dropped.

    >>> tokenize("Bob dropped the apple. Where is the apple?")
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    """
    # The separator group must not be optional. An optional group can match
    # the empty string, and since Python 3.7 ``re.split`` also splits on empty
    # matches, which would break every word into single characters.
    return [x.strip() for x in re.split(r"(\W+)", sent) if x and x.strip()]


def parse_stories(lines, only_supporting=False):
    """Parse stories provided in the bAbI tasks format.

    Every non-blank line looks like ``"<id> <text>"``. An id of 1 starts a
    new story. A line whose text contains a tab is a question made of three
    tab-separated fields: question, answer and space-separated supporting
    line ids. Every other line is a statement.

    Args:
        lines: Iterable of lines, each either ``bytes`` (decoded as UTF-8) or
            ``str``. Blank lines are skipped.
        only_supporting: If true, only the statements referenced by a
            question's supporting ids are kept for that question. Otherwise
            every statement seen so far in the current story is kept.

    Returns:
        A list of ``(substory, question_tokens, answer)`` tuples, where
        ``substory`` is a list of tokenized statements and ``answer`` is the
        raw answer string.
    """
    data = []
    story = []
    for raw_line in lines:
        if isinstance(raw_line, bytes):
            line = raw_line.decode("utf-8")
        else:
            line = raw_line
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing on the id split below.
            continue
        nid, line = line.split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if "\t" in line:
            q, a, supporting = line.split("\t")
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory. Supporting ids are 1-based
                # line ids; the "" placeholders below keep them aligned.
                substory = [story[int(i) - 1] for i in supporting.split()]
            else:
                # Provide all the substories (dropping question placeholders)
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder for the question line so later ids still index
            # ``story`` correctly.
            story.append("")
        else:
            story.append(tokenize(line))
    return data


def get_stories(f, only_supporting=False, max_length=None):
    """Read a bAbI task file and flatten each story into one token list.

    Args:
        f: Open file object; its ``readlines()`` result is handed to
            ``parse_stories`` (in this script it comes from
            ``tarfile.extractfile`` and therefore yields ``bytes``).
        only_supporting: Forwarded to ``parse_stories``.
        max_length: If supplied (truthy), any stories longer than
            ``max_length`` tokens are discarded.

    Returns:
        A list of ``(story_tokens, question_tokens, answer)`` tuples, where
        ``story_tokens`` is the concatenation of the story's sentences.
    """

    def flatten(sentences):
        # Linear-time concatenation; ``sum(lists, [])`` is quadratic.
        return [token for sentence in sentences for token in sentence]

    parsed = parse_stories(f.readlines(), only_supporting=only_supporting)
    data = []
    for story, q, answer in parsed:
        flat_story = flatten(story)  # flatten once, reuse for filter + output
        # Keep stories of at most max_length tokens, as documented above.
        if not max_length or len(flat_story) <= max_length:
            data.append((flat_story, q, answer))
    return data


def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
    """Turn tokenized samples into padded integer arrays.

    Args:
        word_idx: Mapping from token to its integer index.
        story_maxlen: ``maxlen`` passed to ``pad_sequences`` for stories.
        query_maxlen: ``maxlen`` passed to ``pad_sequences`` for queries.
        data: Iterable of ``(story_tokens, query_tokens, answer)`` tuples.

    Returns:
        ``(stories, queries, answers)``: the two ``pad_sequences`` results and
        a NumPy array holding each answer's index.
    """
    # One pass over ``data``; each sample is encoded story -> query -> answer.
    encoded = [
        (
            [word_idx[token] for token in story],
            [word_idx[token] for token in query],
            word_idx[answer],
        )
        for story, query, answer in data
    ]
    story_ids = [sample[0] for sample in encoded]
    query_ids = [sample[1] for sample in encoded]
    answer_ids = [sample[2] for sample in encoded]

    padded_stories = pad_sequences(story_ids, maxlen=story_maxlen)
    padded_queries = pad_sequences(query_ids, maxlen=query_maxlen)
    return padded_stories, padded_queries, np.array(answer_ids)


def read_data(finish_fast=False, challenge_type="single_supporting_fact_10k"):
    """Download (if needed) and load the bAbI train and test stories.

    Args:
        finish_fast: If true, keep only the first 64 train and 64 test
            samples.
        challenge_type: Which bAbI task to load. Either
            ``"single_supporting_fact_10k"`` (QA1, the default) or
            ``"two_supporting_facts_10k"`` (QA2).

    Returns:
        ``(train_stories, test_stories)``, each a list as returned by
        ``get_stories``.

    Raises:
        ValueError: If ``challenge_type`` is not one of the known tasks.
        Exception: Whatever ``get_file`` raises when the download fails; it
            is re-raised after printing manual download instructions.
    """
    # Task files inside the archive, keyed by challenge name.
    challenges = {
        # QA1 with 10,000 samples
        "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
        "single-supporting-fact_{}.txt",
        # QA2 with 10,000 samples
        "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
        "two-supporting-facts_{}.txt",
    }
    # Validate before downloading so a typo fails fast.
    if challenge_type not in challenges:
        raise ValueError(
            "Unknown challenge_type {!r}; expected one of {}".format(
                challenge_type, sorted(challenges)
            )
        )
    challenge = challenges[challenge_type]

    # Get the file
    try:
        path = get_file(
            "babi-tasks-v1-2.tar.gz",
            origin="https://s3.amazonaws.com/text-datasets/"
            "babi_tasks_1-20_v1-2.tar.gz",
        )
    except Exception:
        # Report on stderr, then re-raise the original error.
        print(
            "Error downloading dataset, please download it manually:\n"
            "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2"  # noqa: E501
            ".tar.gz\n"
            "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz",  # noqa: E501
            file=sys.stderr,
        )
        raise

    with tarfile.open(path) as tar:
        # Close each extracted member explicitly once it has been parsed.
        with tar.extractfile(challenge.format("train")) as train_file:
            train_stories = get_stories(train_file)
        with tar.extractfile(challenge.format("test")) as test_file:
            test_stories = get_stories(test_file)
    if finish_fast:
        train_stories = train_stories[:64]
        test_stories = test_stories[:64]
    return train_stories, test_stories


class MemNNModel(tune.Trainable):
    """Ray Tune ``Trainable`` that trains a Keras memory network on bAbI.

    Hyperparameters are read from the trial config with these defaults:
    ``finish_fast`` (False), ``dropout`` (0.3), ``lr`` (1e-3), ``rho`` (0.9),
    ``batch_size`` (32) and ``epochs`` (1).
    """

    def build_model(self):
        """Vectorize the loaded stories and build the (uncompiled) Keras model.

        Requires ``self.train_stories`` and ``self.test_stories`` to be set.
        As a side effect, stores the padded train/test arrays on
        ``self.inputs_*``, ``self.queries_*`` and ``self.answers_*``.

        Returns:
            A Keras ``Model`` mapping ``[story, question]`` index sequences to
            a softmax over the vocabulary.
        """
        all_stories = self.train_stories + self.test_stories
        dropout = self.config.get("dropout", 0.3)

        vocab = set()
        for story, q, answer in all_stories:
            vocab |= set(story + q + [answer])
        vocab = sorted(vocab)

        # Reserve 0 for masking via pad_sequences
        vocab_size = len(vocab) + 1
        story_maxlen = max(len(x) for x, _, _ in all_stories)
        query_maxlen = max(len(x) for _, x, _ in all_stories)

        word_idx = {c: i + 1 for i, c in enumerate(vocab)}
        self.inputs_train, self.queries_train, self.answers_train = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.train_stories
        )
        self.inputs_test, self.queries_test, self.answers_test = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.test_stories
        )

        # placeholders
        input_sequence = Input((story_maxlen,))
        question = Input((query_maxlen,))

        # encoders
        # embed the input sequence into a sequence of vectors
        input_encoder_m = Sequential()
        input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
        input_encoder_m.add(Dropout(dropout))
        # output: (samples, story_maxlen, embedding_dim)

        # embed the input into a sequence of vectors of size query_maxlen
        input_encoder_c = Sequential()
        input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen))
        input_encoder_c.add(Dropout(dropout))
        # output: (samples, story_maxlen, query_maxlen)

        # embed the question into a sequence of vectors. The sequence length
        # comes from the `question` Input, so, like the encoders above, no
        # `input_length` is passed.
        question_encoder = Sequential()
        question_encoder.add(Embedding(input_dim=vocab_size, output_dim=64))
        question_encoder.add(Dropout(dropout))
        # output: (samples, query_maxlen, embedding_dim)

        # encode input sequence and questions (which are indices)
        # to sequences of dense vectors
        input_encoded_m = input_encoder_m(input_sequence)
        input_encoded_c = input_encoder_c(input_sequence)
        question_encoded = question_encoder(question)

        # compute a "match" between the first input vector sequence
        # and the question vector sequence
        # shape: `(samples, story_maxlen, query_maxlen)`
        match = dot([input_encoded_m, question_encoded], axes=(2, 2))
        match = Activation("softmax")(match)

        # add the match matrix with the second input vector sequence
        response = add([match, input_encoded_c])  # (samples, story_maxlen, query_maxlen)
        response = Permute((2, 1))(response)  # (samples, query_maxlen, story_maxlen)

        # concatenate the match matrix with the question vector sequence
        answer = concatenate([response, question_encoded])

        # the original paper uses a matrix multiplication.
        # we choose to use a RNN instead.
        answer = LSTM(32)(answer)  # (samples, 32)

        # one regularization layer -- more would probably be needed.
        answer = Dropout(dropout)(answer)
        answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
        # we output a probability distribution over the vocabulary
        answer = Activation("softmax")(answer)

        # build the final model
        model = Model([input_sequence, question], answer)
        return model

    def setup(self, config):
        """Load the data (serialized across trials by a file lock) and compile."""
        with FileLock(os.path.expanduser("~/.tune.lock")):
            self.train_stories, self.test_stories = read_data(
                config.get("finish_fast", False)
            )
        model = self.build_model()
        # `learning_rate` is the optimizer's argument name; the old `lr`
        # alias is rejected by newer Keras optimizers.
        rmsprop = RMSprop(
            learning_rate=self.config.get("lr", 1e-3),
            rho=self.config.get("rho", 0.9),
        )
        model.compile(
            optimizer=rmsprop,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        self.model = model

    def step(self):
        """Run one round of training and report accuracy to Tune."""
        # train
        self.model.fit(
            [self.inputs_train, self.queries_train],
            self.answers_train,
            batch_size=self.config.get("batch_size", 32),
            epochs=self.config.get("epochs", 1),
            validation_data=([self.inputs_test, self.queries_test], self.answers_test),
            verbose=0,
        )
        # NOTE(review): this evaluates on the training split, so the reported
        # `mean_accuracy` (the PBT objective) is training accuracy; the test
        # split is only used as `validation_data` above.
        _, accuracy = self.model.evaluate(
            [self.inputs_train, self.queries_train], self.answers_train, verbose=0
        )
        return {"mean_accuracy": accuracy}

    def save_checkpoint(self, checkpoint_dir):
        """Save the Keras model to ``<checkpoint_dir>/model``."""
        # NOTE(review): newer Keras releases require a `.keras`/`.h5`
        # extension in `Model.save`; confirm this path works with the
        # installed version.
        file_path = os.path.join(checkpoint_dir, "model")
        self.model.save(file_path)

    def load_checkpoint(self, checkpoint_dir):
        """Replace the current model with the one saved in ``checkpoint_dir``."""
        # See https://stackoverflow.com/a/42763323
        del self.model
        file_path = os.path.join(checkpoint_dir, "model")
        self.model = load_model(file_path)


if __name__ == "__main__":
    import ray
    from ray.tune.schedulers import PopulationBasedTraining

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    cli_args, _unknown = cli.parse_known_args()
    smoke_test = cli_args.smoke_test

    if smoke_test:
        # Limit the smoke test to 2 CPUs.
        ray.init(num_cpus=2)

    # PBT perturbs hyperparameters on this interval, and checkpoints are
    # taken on the same cadence (see `checkpoint_frequency` below).
    perturbation_interval = 2
    mutations = {
        "dropout": lambda: np.random.uniform(0, 1),
        "lr": lambda: 10 ** np.random.randint(-10, 0),
        "rho": lambda: np.random.uniform(0, 1),
    }
    pbt = PopulationBasedTraining(
        perturbation_interval=perturbation_interval,
        hyperparam_mutations=mutations,
    )

    max_iterations = 4 if smoke_test else 100
    run_config = tune.RunConfig(
        name="pbt_babi_memnn",
        stop={"training_iteration": max_iterations},
        checkpoint_config=tune.CheckpointConfig(
            checkpoint_frequency=perturbation_interval,
            checkpoint_score_attribute="mean_accuracy",
            num_to_keep=2,
        ),
    )
    tune_config = tune.TuneConfig(
        scheduler=pbt,
        metric="mean_accuracy",
        mode="max",
        num_samples=2,
        reuse_actors=True,
    )
    # Initial hyperparameters for every trial.
    param_space = {
        "finish_fast": smoke_test,
        "batch_size": 32,
        "epochs": 1,
        "dropout": 0.3,
        "lr": 0.01,
        "rho": 0.9,
    }

    tuner = tune.Tuner(
        MemNNModel,
        run_config=run_config,
        tune_config=tune_config,
        param_space=param_space,
    )
    tuner.fit()
